[KDA] sm90 GVA enhance #64
Conversation
Code Review
This pull request introduces Grouped Value Attention (GVA) support to the KDA forward prefill kernel for Hopper architectures. The changes involve decoupling Q/K head counts from V/O head counts, updating the tile scheduler to share Q/K within GVA groups, and adjusting TMA load/store logic. The Python interface was also updated with enhanced validation and a fix for the output_final_state return logic. Review feedback identifies a potential optimization to avoid unnecessary memory allocation and GPU bandwidth usage by conditionally skipping output_state initialization when the final state is not needed.
```diff
 torch::Tensor output_state = output_state_.has_value()
     ? output_state_.value()
     : torch::zeros(
-          {num_seqs, num_heads, head_size, head_size},
+          {num_seqs, num_v_heads, head_size, head_size},
           torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));
```
The output_state tensor is always allocated and zero-initialized even when the caller does not require the final state (i.e., when output_final_state=False in the Python API). For large models or long sequences, this results in significant unnecessary memory allocation and GPU bandwidth consumption during the kernel's write-back phase. Consider passing a flag to the kernel to skip the state store, or at least avoid the allocation in the C++ API if the result is not needed by the caller.
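One possible shape of that fix on the C++ side, sketched against the quoted snippet; `store_final_state` is a hypothetical flag that would have to be threaded down from the Python `output_final_state` argument:

```cpp
// Sketch only, not the PR's code. `store_final_state` is a hypothetical flag
// passed from the Python wrapper; the other names come from the quoted snippet.
torch::Tensor output_state;
if (output_state_.has_value()) {
  output_state = output_state_.value();
} else if (store_final_state) {
  output_state = torch::zeros(
      {num_seqs, num_v_heads, head_size, head_size},
      torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));
} else {
  // Empty placeholder: the kernel would also need a flag to skip the state
  // write-back so no bandwidth is spent on an unused tensor.
  output_state = torch::empty(
      {0}, torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));
}
```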
Thanks for your contribution. Could you attach the performance results with this PR?

Okay, please wait a moment.
It seems that the benchmark does not contain a GVA setting. Could you add the GVA data preparation in `benchmark.utils` and add the GVA benchmark?

👌
Hi, this PR #66 removes the redundant zero init for `output_final_state=False`; you can merge it into your kda_sm90 interface, thanks~

OK
```python
    the effective head counts uniformly.
    """
    HV = H if num_v_heads is None else num_v_heads
    assert H > 0, f"H must be positive, got {H}."
```
These two asserts (`H > 0`, `HV > 0`) can be deleted.
| "GVA rows (HV > H) are mixed in alongside MHA rows (HV == H) " | ||
| "under the same sequence-length settings, so GVA and MHA can be " | ||
| "compared side by side." |
This comment can be deleted.
```python
    # GVA (HV > H) at the same (B, T) shapes for side-by-side comparison:
    (1, 1024, 16, 64),    # 4x
    (1, 4096, 16, 64),    # 4x
    (1, 8192, 16, 64),    # 4x
    (1, 16384, 16, 64),   # 4x
    (1, 4096, 32, 64),    # 2x
    (1, 8192, 32, 64),    # 2x
    (1, 4096, 8, 64),     # 8x
    (1, 8192, 8, 64),     # 8x
    (2, 4096, 16, 64),    # 4x
    (2, 8192, 16, 64),    # 4x
```
The HV parameter can be specified by the user with `--hv`. Since we already have an HV parameter, these test settings are no longer needed; just restoring them (i.e., leaving them unmodified) is OK.
```diff
 # Varlen configs — identical sequence-length layouts replayed with and
 # without GVA so MHA and GVA can be compared row-by-row.
 varlen_configs_base = build_varlen_configs(
     num_seqs_list=(10, 20),
-    total_lens=(4096, 8192, 16384),
+    total_lens=(4096, 8192),
     dists=("uniform", "random", "skewed"),
 )
 gva_varlen_mixed = [
     (seq_lens, T, dist, H_qk, HV)
     for (H_qk, HV) in ((16, 64), (32, 64))
     for (seq_lens, T, dist) in varlen_configs_base
 ]
 varlen_configs = list(varlen_configs_base) + gva_varlen_mixed
```
```python
# Config normalization helpers
# ============================================================
def _normalize_fixed_config(cfg):
    """Accept either (B, T) or (B, T, H_qk, HV) and return the 4-tuple form.

    For the 2-tuple legacy form, defaults to H_qk=HV=H (no GVA).
    """
    if len(cfg) == 2:
        B, T = cfg
        return B, T, H, H
    if len(cfg) == 4:
        return cfg
    raise ValueError(f"Fixed config must be (B, T) or (B, T, H, HV), got {cfg!r}")
```
These normalization helpers are no longer needed either.
```python
    H=H,
    HV=HV,
```
feat(kda, sm90): add KDA GVA forward support
Summary
Add Grouped Value Attention (GVA) forward support to the KDA Hopper / SM90 prefill path: Q/K share `num_qk_heads`, while V, g, β, O and the recurrent state are sized by `num_v_heads`.

Constraints

- `num_q_heads == num_k_heads`
- `num_v_heads >= num_qk_heads`
- `num_v_heads % num_qk_heads == 0`
- `head_dim == 128` (unchanged)

`heads_per_group = num_v_heads / num_qk_heads`. Each value head is mapped to exactly one shared Q/K head (sketched below).
When `num_v_heads == num_qk_heads`, this degenerates back to standard MHA.
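A minimal sketch of that mapping, assuming the usual block grouping in which consecutive value heads share one Q/K head (the helper name is illustrative, not from the PR):

```cpp
#include <cassert>

// Each group of heads_per_group consecutive value heads shares one Q/K head.
// With num_v_heads == num_qk_heads the group size is 1 and this is plain MHA.
int qk_head_for_v_head(int v_head_idx, int num_qk_heads, int num_v_heads) {
  assert(num_qk_heads > 0 && num_v_heads % num_qk_heads == 0);
  const int heads_per_group = num_v_heads / num_qk_heads;
  return v_head_idx / heads_per_group;  // e.g. H=16, HV=64: v heads 0..3 -> qk head 0
}
```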
What changed

Across the full Python → C++ → kernel stack:
Python interface (`cula/kda/hopper_fused_fwd.py`)

- Takes decoupled `num_qk_heads` and `num_v_heads`: `q, k = [B, T, H, D]`, `v, g = [B, T, HV, D]`, `beta = [B, T, HV]`, `initial_state = [N, HV, D, D]`.
- Fixes the return logic for `output_final_state=False` (the wrapper used to leak the kernel-allocated state tensor).
C++ API (`csrc/api/kda_sm90.cu`)

- `num_heads` → `num_qk_heads` / `num_v_heads`.
- Output buffers: `output [packed, HV, D]` and `output_state [N, HV, D, D]`.
- `input_state` shapes are validated against `num_v_heads`.
- Checks `num_qk_heads > 0`, `num_v_heads > 0`, then `% == 0`, so a degenerate input never triggers division-by-zero UB (see the sketch below).
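For clarity, a sketch of that check ordering as it might look inside the API entry point (messages are illustrative; the variables are the API's head-count arguments):

```cpp
// Positivity is checked before the modulus/division, so num_qk_heads == 0
// can never reach the division below (no UB on degenerate inputs).
TORCH_CHECK(num_qk_heads > 0, "num_qk_heads must be positive, got ", num_qk_heads);
TORCH_CHECK(num_v_heads > 0, "num_v_heads must be positive, got ", num_v_heads);
TORCH_CHECK(num_v_heads % num_qk_heads == 0,
            "num_v_heads (", num_v_heads, ") must be a multiple of num_qk_heads (",
            num_qk_heads, ")");
const int64_t heads_per_group = num_v_heads / num_qk_heads;
```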
Kernels (`kernel_kda_fwd.hpp`, `prefill_kernel*.hpp/.cuh`, `kda_fwd_sm90.cu`, `kda_fwd_sm90_safe_gate.cu`)

- `VarlenProblemShape` carries both `num_qk_heads` and `num_v_heads`.
- `qk_tok_stride = num_qk_heads * head_size` for Q/K; `v_tok_stride = num_v_heads * head_size` for V/O/α.

Tile scheduler (`csrc/kda/sm90/kernel/tile_scheduler.hpp`)

- `WorkDesc` now exposes `q_head_idx()` / `k_head_idx()` returning `qk_head_idx`, while `v_head_idx()` / `o_head_idx()` return `head_idx`.
- `grid.x = num_seqs * num_v_heads`; one program per `(seq, v_head)` (sketched below).
- `heads_per_group` is computed once on the host in `to_underlying_arguments` and stored in `Params`, so the device path avoids recomputing `num_v_heads / num_qk_heads` per CTA.
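A simplified sketch of the per-program index math, assuming a plain row-major decomposition of the linear program index (the real `WorkDesc` carries more state and the ordering may differ):

```cpp
// Illustrative decomposition of one program index into (seq, v_head, qk_head).
// heads_per_group = num_v_heads / num_qk_heads is precomputed on the host.
struct WorkSketch {
  int seq_idx;
  int v_head_idx;   // also the O head index
  int qk_head_idx;  // shared Q/K head for this value head's group
};

__device__ inline WorkSketch decode_work(int linear_idx, int num_v_heads,
                                         int heads_per_group) {
  WorkSketch w;
  w.seq_idx     = linear_idx / num_v_heads;
  w.v_head_idx  = linear_idx % num_v_heads;
  w.qk_head_idx = w.v_head_idx / heads_per_group;
  return w;
}
```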
TMA load/store and mainloop (`load_tma.hpp`, `store_tma.hpp`, `mainloop_kda_fwd.hpp`)

- `load_tma.hpp`: Q and K are sliced over `num_qk_heads`, V/α/O are sliced over `num_v_heads`. The K-vs-V branch uses a `constexpr` selector since `LoadKind` is a static template parameter, so the head-count selection collapses at compile time (see the sketch after this list).
- `store_tma.hpp`: O is written in the V/O head space.
- `mainloop_kda_fwd.hpp`: QK GEMM batched over `num_qk_heads`; KV / α / output / state buffers all batched over `num_v_heads`.
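A sketch of that compile-time selection; the enum values and helper are illustrative, only the idea of `LoadKind` as a static template parameter comes from the PR:

```cpp
enum class LoadKind { Q, K, V };  // illustrative values

// Because kKind is a template parameter, the branch below is resolved at
// compile time and only one head count survives into the generated code.
template <LoadKind kKind>
__device__ inline int heads_for_load(int num_qk_heads, int num_v_heads) {
  if constexpr (kKind == LoadKind::Q || kKind == LoadKind::K) {
    return num_qk_heads;  // Q and K share the grouped head count
  } else {
    return num_v_heads;   // V (and O/alpha on the store side) use the value head count
  }
}
```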
Tests
Performance
Related